package com.jstarcraft.ai.jsat.datatransform.kernel;

import java.util.Random;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.datatransform.DataTransformBase;
import com.jstarcraft.ai.jsat.distributions.Distribution;
import com.jstarcraft.ai.jsat.distributions.kernels.RBFKernel;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.Matrix;
import com.jstarcraft.ai.jsat.linear.RandomMatrix;
import com.jstarcraft.ai.jsat.linear.RandomVector;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

/**
 * An implementation of Random Fourier Features for the {@link RBFKernel}. It
 * transforms the numerical variables of a feature space to form a new feature
 * space where the dot product between features approximates the RBF Kernel
 * product. <br>
 * <br>
 * The random projection is created either by {@link #fit(DataSet)} or by the
 * constructor that accepts the original feature size; one of the two must have
 * run before {@link #transform(DataPoint)} is called, since that method reads
 * the projection matrix and the offset vector. <br>
 * <br>
 * See: Rahimi, A., &amp; Recht, B. (2007). <i>Random Features for Large-Scale
 * Kernel Machines</i>. Neural Information Processing Systems. Retrieved from
 * <a href=
 * "http://seattle.intel-research.net/pubs/rahimi-recht-random-features.pdf">
 * here</a>
 * 
 * @author Edward Raff
 */
public class RFF_RBF extends DataTransformBase {

    private static final long serialVersionUID = -3478216020648280477L;
    /** The random projection matrix; {@code null} until initialized. */
    private Matrix transform;
    /** The random phase offsets added before the cosine; {@code null} until initialized. */
    private Vec offsets;
    private double sigma;
    private int dim;
    private boolean inMemory;

    /**
     * Creates a new RFF RBF object that will use a transformed feature space with
     * a dimensionality of 512 and a sigma of 1.0. This constructor should be used
     * with a parameter search to find a good value for {@link #setSigma(double)
     * sigma}
     */
    public RFF_RBF() {
        this(1.0);
    }

    /**
     * Creates a new RFF RBF object that will use a transformed feature space with
     * a dimensionality of 512.
     *
     * @param sigma the positive sigma value for the {@link RBFKernel}
     */
    public RFF_RBF(double sigma) {
        this(sigma, 512);
    }

    /**
     * Creates a new RFF RBF object that stores its projection in memory.
     *
     * @param sigma the positive sigma value for the {@link RBFKernel}
     * @param dim   the new feature size dimension to project into.
     */
    public RFF_RBF(double sigma, int dim) {
        this(sigma, dim, true);
    }

    /**
     * Creates a new RFF RBF object. The projection itself is not created until
     * {@link #fit(DataSet)} is called.
     *
     * @param sigma    the positive sigma value for the {@link RBFKernel}
     * @param dim      the new feature size dimension to project into.
     * @param inMemory {@code true} if the internal matrix should be stored in
     *                 memory. If {@code false}, the matrix will be re-computed as
     *                 needed, increasing computation cost but using no extra
     *                 memory.
     * @throws IllegalArgumentException if {@code sigma} is not a positive finite
     *                                  value
     * @throws ArithmeticException      if {@code dim} is less than 1
     */
    public RFF_RBF(double sigma, int dim, boolean inMemory) {
        setSigma(sigma);
        setDimensions(dim);
        setInMemory(inMemory);
    }

    /**
     * Creates a new RFF RBF object whose projection is initialized immediately
     * from the given source of randomness, so no call to {@link #fit(DataSet)} is
     * needed.
     * 
     * @param featurSize the number of numeric features in the original feature
     *                   space
     * @param sigma      the positive sigma value for the {@link RBFKernel}
     * @param dim        the new feature size dimension to project into.
     * @param rand       the source of randomness to initialize internal state
     * @param inMemory   {@code true} if the internal matrix should be stored in
     *                   memory. If {@code false}, the matrix will be re-computed as
     *                   needed, increasing computation cost but using no extra
     *                   memory.
     * @throws IllegalArgumentException if {@code featurSize} is not positive or
     *                                  {@code sigma} is not a positive finite value
     * @throws ArithmeticException      if {@code dim} is less than 1
     */
    public RFF_RBF(int featurSize, double sigma, int dim, Random rand, boolean inMemory) {
        // Delegation validates sigma and dim through the setters, so they are not
        // re-checked here. (The previous extra check rejected dim == 1 even though
        // setDimensions, and therefore the fit path, accepts it.)
        this(sigma, dim, inMemory);
        if (featurSize <= 0) {
            throw new IllegalArgumentException("The number of numeric features must be positive, not " + featurSize);
        }
        initialize(featurSize, rand);
    }

    /**
     * Builds a fresh random projection sized for the number of numerical
     * variables reported by the given data set, replacing any existing one.
     *
     * @param data the data set whose numerical variable count determines the
     *             number of rows of the projection
     */
    @Override
    public void fit(DataSet data) {
        initialize(data.getNumNumericalVars(), RandomUtil.getRandom());
    }

    /**
     * Creates the random projection matrix and offset vector from the current
     * sigma and dimension settings, materializing them when in-memory storage is
     * requested.
     *
     * @param featureSize the number of numeric features in the original space
     * @param rand        the source of the two seeds used for the matrix and the
     *                    offsets (drawn in that order)
     */
    private void initialize(int featureSize, Random rand) {
        final double scale = Math.sqrt(0.5 / (sigma * sigma));
        Matrix projection = new RandomMatrixRFF_RBF(scale, featureSize, dim, rand.nextLong());
        Vec phases = new RandomVectorRFF_RBF(dim, rand.nextLong());

        if (inMemory) {
            // per the original author's note: add(0.0) copies into a new mutable
            // matrix while adding nothing
            projection = projection.add(0.0);
            phases = new DenseVector(phases);
        }

        this.transform = projection;
        this.offsets = phases;
    }

    /**
     * Copy constructor
     * 
     * @param toCopy the object to copy
     */
    protected RFF_RBF(RFF_RBF toCopy) {
        // The projection may not have been initialized yet, so guard the clones.
        if (toCopy.transform != null) {
            this.transform = toCopy.transform.clone();
        }
        if (toCopy.offsets != null) {
            this.offsets = toCopy.offsets.clone();
        }
        this.dim = toCopy.dim;
        this.inMemory = toCopy.inMemory;
        this.sigma = toCopy.sigma;
    }

    /**
     * Projects the numerical values of the data point through the random matrix,
     * adds the per-dimension offset, and applies a scaled cosine. The categorical
     * values and categorical data are passed through unchanged.
     *
     * @param dp the data point to transform
     * @return a new data point holding the transformed numerical values
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        Vec oldX = dp.getNumericalValues();
        Vec newX = oldX.multiply(transform);

        // sqrt(2 / D) scaling, where D is the number of projected dimensions
        final double coef = Math.sqrt(2.0 / transform.cols());
        for (int i = 0; i < newX.length(); i++) {
            newX.set(i, Math.cos(newX.get(i) + offsets.get(i)) * coef);
        }

        return new DataPoint(newX, dp.getCategoricalValues(), dp.getCategoricalData());
    }

    @Override
    public RFF_RBF clone() {
        return new RFF_RBF(this);
    }

    /**
     * Random matrix whose entries are Gaussian samples scaled by a fixed
     * coefficient.
     */
    private static class RandomMatrixRFF_RBF extends RandomMatrix {

        private static final long serialVersionUID = 4702514384718636893L;
        private double coef;

        public RandomMatrixRFF_RBF(double coef, int rows, int cols, long seedMult) {
            super(rows, cols, seedMult);
            this.coef = coef;
        }

        @Override
        protected double getVal(Random rand) {
            return coef * rand.nextGaussian();
        }
    }

    /**
     * Random vector whose entries are uniform samples from [0, 2&pi;).
     */
    private static class RandomVectorRFF_RBF extends RandomVector {

        private static final long serialVersionUID = -6132378281909907937L;

        public RandomVectorRFF_RBF(int length, long seedMult) {
            super(length, seedMult);
        }

        @Override
        protected double getVal(Random rand) {
            return rand.nextDouble() * 2 * Math.PI;
        }

        @Override
        public Vec clone() {
            // Returns this instance rather than a copy; presumably safe because
            // the values are derived from the seed rather than stored - TODO confirm
            return this;
        }

    }

    /**
     * Sets whether or not the transform matrix is stored explicitly in memory or
     * not. Explicit storage is often faster, but can be prohibitive for large
     * feature sizes. The setting is read when the projection is next created; an
     * already created projection is not converted.
     * 
     * @param inMemory {@code true} to explicitly store the transform matrix,
     *                 {@code false} to re-create it on the fly as needed
     */
    public void setInMemory(boolean inMemory) {
        this.inMemory = inMemory;
    }

    /**
     * Returns whether the transform matrix is stored explicitly in memory.
     * 
     * @return {@code true} if this object will explicitly store the transform
     *         matrix, {@code false} to re-create it on the fly as needed
     */
    public boolean isInMemory() {
        return inMemory;
    }

    /**
     * Sets the number of dimensions in the new approximate space to use. This will
     * be the number of numeric features in the transformed data, and larger values
     * increase the accuracy of the approximation.
     *
     * @param dimensions the positive number of dimensions to project into
     * @throws ArithmeticException if {@code dimensions} is less than 1
     */
    public void setDimensions(int dimensions) {
        if (dimensions < 1) {
            throw new ArithmeticException("Number of dimensions must be a positive value, not " + dimensions);
        }
        this.dim = dimensions;
    }

    /**
     * Returns the number of dimensions that will be used in the projected space
     *
     * @return the number of dimensions that will be used in the projected space
     */
    public int getDimensions() {
        return dim;
    }

    /**
     * Sets the &sigma; parameter of the RBF kernel that is being approximated.
     *
     * @param sigma the positive value to use for &sigma;
     * @throws IllegalArgumentException if {@code sigma} is not positive, is
     *                                  infinite, or is NaN
     * @see RBFKernel#setSigma(double)
     */
    public void setSigma(double sigma) {
        if (sigma <= 0.0 || Double.isInfinite(sigma) || Double.isNaN(sigma)) {
            throw new IllegalArgumentException("Sigma must be a positive value, not " + sigma);
        }
        this.sigma = sigma;
    }

    /**
     * Returns the &sigma; value used for the RBF kernel approximation.
     *
     * @return the &sigma; value used for the RBF kernel approximation.
     */
    public double getSigma() {
        return sigma;
    }

    /**
     * Guess the distribution to use for the kernel width term
     * {@link #setSigma(double) &sigma;} in the RBF kernel being approximated.
     * Delegates to {@link RBFKernel#guessSigma(DataSet)}.
     *
     * @param d the data set to get the guess for
     * @return the guess for the &sigma; parameter in the RBF Kernel
     */
    public Distribution guessSigma(DataSet d) {
        return RBFKernel.guessSigma(d);
    }
}
